{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e1a1c491",
   "metadata": {},
   "source": [
    "# Free Generation design\n",
    "\n",
    "This notebook demonstrates the Free Generation design task from the paper [Language models generalize beyond natural proteins\n",
    "](https://www.biorxiv.org/content/10.1101/2022.12.21.521521v1).\n",
    "\n",
    "Given an input sequence length, the design code activates the language model and the linear projection structure head alternatingly to generate a likely sequence and structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f76a438-3afd-4096-ba36-eb9bee90f08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First install additional dependencies\n",
    "!pip install -r additional_requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59549997",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import time\n",
    "import hydra\n",
    "import py3Dmol\n",
    "from lm_design import Designer\n",
    "\n",
    "# Params\n",
    "seed = 0  # Use different seeds to get different sequence results\n",
    "free_generation_length = 100\n",
    "TASK = \"free_generation\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e27bf90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load hydra config from config.yaml\n",
    "with hydra.initialize_config_module(config_module=\"conf\"):\n",
    "    cfg = hydra.compose(\n",
    "        config_name=\"config\", \n",
    "        overrides=[\n",
    "            f\"task={TASK}\", \n",
    "            f\"seed={seed}\", \n",
    "            f\"free_generation_length={free_generation_length}\", \n",
    "            # 'tasks.free_generation.num_iter=100'  # DEBUG - use a smaller number of iterations\n",
    "        ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a95fbfb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a designer from configuration\n",
    "des = Designer(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e92e577b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Run the designer\n",
    "start_time = time.time()\n",
    "des.run_from_cfg()    \n",
    "print(\"finished after %s hours\", (time.time() - start_time) / 3600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb4d9456",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Output seq:\", des.output_seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12229064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fold output with ESMFold API\n",
    "output_seq = des.output_seq\n",
    "# Fold with api:\n",
    "#  curl -X POST --data \"GENGEIPLEIRATTGAEVDTRAVTAVEMTEGTLGIFRLPEEDYTALENFRYNRVAGENWKPASTVIYVGGTYARLCAYAPYNSVEFKNSSLKTEAGLTMQTYAAEKDMRFAVSGGDEVWKKTPTANFELKRAYARLVLSVVRDATYPNTCKITKAKIEAFTGNIITANTVDISTGTEGSGTQTPQYIHTVTTGLKDGFAIGLPQQTFSGGVVLTLTVDGMEYSVTIPANKLSTFVRGTKYIVSLAVKGGKLTLMSDKILIDKDWAEVQTGTGGSGDDYDTSFN\" https://api.esmatlas.com/foldSequence/v1/pdb/\n",
    "import requests\n",
    "import json\n",
    "url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'\n",
    "r = requests.post(url, data=output_seq)\n",
    "output_struct = r.text\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e516f069",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize output structure\n",
    "view = py3Dmol.view(width=800, height=800)\n",
    "view.addModel(output_struct, 'pdb')\n",
    "view.setStyle({'cartoon': {'color': 'spectrum'}})\n",
    "view.zoomTo()\n",
    "view.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1c9a0f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.12 ('esm-IF_py37')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "5502aca739f2549ad2771378ffc455b2bbb8b06f1a91617971f7097758a3cf84"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
